{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os.path\n",
    "import helper_cityscapes as helper\n",
    "import warnings\n",
    "import scipy.misc\n",
    "import tensorflow as tf\n",
    "from datetime import timedelta\n",
    "from distutils.version import LooseVersion\n",
    "import project_tests as tests\n",
    "\n",
    "# Check TensorFlow Version\n",
    "tf_version = tf.__version__\n",
    "assert LooseVersion(tf_version) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer.  You are using {}'.format(tf_version)\n",
    "print('TensorFlow Version: {}'.format(tf_version))\n",
    "\n",
    "# Check for a GPU (an empty device name means TensorFlow sees no GPU)\n",
    "gpu_name = tf.test.gpu_device_name()\n",
    "if gpu_name:\n",
    "    print('Default GPU Device: {}'.format(gpu_name))\n",
    "else:\n",
    "    warnings.warn('No GPU found. Please use a GPU to train your neural network.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Paths\n",
    "DATA_DIR = './data'\n",
    "RUNS_DIR = './runs_cityscapes'\n",
    "MODEL_DIR = './models_cityscapes'\n",
    "\n",
    "# Data\n",
    "IMAGE_SHAPE = (256, 512)  # presumably (height, width) the images are resized to; confirm in helper_cityscapes\n",
    "NUM_CLASSES = 3\n",
    "\n",
    "# Decoder initialisation / regularisation\n",
    "STDEV = 1e-3   # stddev of the random-normal kernel initializer\n",
    "L2_REG = 1e-6  # scale of the L2 kernel regularizer\n",
    "\n",
    "# Training\n",
    "KEEP_PROB = 0.5  # dropout keep probability fed while training\n",
    "LEARNING_RATE = 1e-4\n",
    "EPOCHS = 30\n",
    "BATCH_SIZE = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_vgg(sess, vgg_path):\n",
    "    \"\"\"\n",
    "    Load the pretrained VGG16 SavedModel (tag 'vgg16') into `sess` and fetch the tensors the decoder builds on.\n",
    "    :param sess: TensorFlow Session to restore the model into\n",
    "    :param vgg_path: Path to vgg folder, containing \"variables/\" and \"saved_model.pb\"\n",
    "    :return: Tuple of Tensors from VGG model (image_input, keep_prob, layer3_out, layer4_out, layer7_out)\n",
    "    \"\"\"\n",
    "    # Order matters: callers unpack the returned tuple positionally.\n",
    "    wanted = ('image_input:0', 'keep_prob:0', 'layer3_out:0', 'layer4_out:0', 'layer7_out:0')\n",
    "    graph = tf.get_default_graph()\n",
    "    tf.saved_model.loader.load(sess, ['vgg16'], vgg_path)\n",
    "    return tuple(graph.get_tensor_by_name(tensor_name) for tensor_name in wanted)\n",
    "\n",
    "print(\"Load VGG Model:\")\n",
    "tests.test_load_vgg(load_vgg, tf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _conv_1x1(inputs, num_classes):\n",
    "    \"\"\"1x1 convolution mapping `inputs` to `num_classes` channels, with the shared STDEV init and L2_REG penalty.\"\"\"\n",
    "    return tf.layers.conv2d(inputs, num_classes, 1, 1,\n",
    "                            padding='same', kernel_initializer=tf.random_normal_initializer(stddev=STDEV),\n",
    "                            kernel_regularizer=tf.contrib.layers.l2_regularizer(L2_REG))\n",
    "\n",
    "\n",
    "def _upsample(inputs, num_classes, kernel_size, stride):\n",
    "    \"\"\"Transposed convolution upsampling `inputs` by `stride`, with the shared STDEV init and L2_REG penalty.\"\"\"\n",
    "    return tf.layers.conv2d_transpose(inputs, num_classes, kernel_size, stride,\n",
    "                                      padding='same', kernel_initializer=tf.random_normal_initializer(stddev=STDEV),\n",
    "                                      kernel_regularizer=tf.contrib.layers.l2_regularizer(L2_REG))\n",
    "\n",
    "\n",
    "def layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes):\n",
    "    \"\"\"\n",
    "    Create the layers for a fully convolutional network.  Build skip-layers using the vgg layers.\n",
    "    Layer 7 is upsampled x2, fused with layer 4, upsampled x2, fused with layer 3, then upsampled x8.\n",
    "    :param vgg_layer3_out: TF Tensor for VGG Layer 3 output\n",
    "    :param vgg_layer4_out: TF Tensor for VGG Layer 4 output\n",
    "    :param vgg_layer7_out: TF Tensor for VGG Layer 7 output\n",
    "    :param num_classes: Number of classes to classify\n",
    "    :return: The Tensor for the last layer of output\n",
    "    \"\"\"\n",
    "    # Keep the layer creation order unchanged: tf.layers auto-generates layer names in creation\n",
    "    # order, and checkpoints are restored by variable name.\n",
    "    output = _upsample(_conv_1x1(vgg_layer7_out, num_classes), num_classes, 4, 2)\n",
    "    output = tf.add(output, _conv_1x1(vgg_layer4_out, num_classes))  # skip connection from layer 4\n",
    "    output = _upsample(output, num_classes, 4, 2)\n",
    "    output = tf.add(output, _conv_1x1(vgg_layer3_out, num_classes))  # skip connection from layer 3\n",
    "    return _upsample(output, num_classes, 16, 8)\n",
    "\n",
    "print(\"Layers Test:\")\n",
    "tests.test_layers(layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss with weights\n",
    "# Per-class weights for the weighted cross-entropy in optimize(), one entry per label channel.\n",
    "# Classes are unbalanced, that is why we can add some weight to the road class.\n",
    "# From https://github.com/MarvinTeichmann/KittiSeg\n",
    "weights = [0.3, 0.6, 0.3]\n",
    "\n",
    "def optimize(nn_last_layer, correct_label, learning_rate, num_classes, class_weights=None):\n",
    "    \"\"\"\n",
    "    Build the TensorFlow loss and optimizer operations.\n",
    "    The loss is a class-weighted cross-entropy. The L2 kernel penalties registered in the graph are added\n",
    "    to the objective RMSProp minimizes, while the returned loss is the cross-entropy term alone.\n",
    "    :param nn_last_layer: TF Tensor of the last layer in the neural network\n",
    "    :param correct_label: TF Placeholder for the correct label image\n",
    "    :param learning_rate: TF Placeholder for the learning rate\n",
    "    :param num_classes: Number of classes to classify\n",
    "    :param class_weights: Sequence of per-class loss weights (length num_classes); defaults to the module-level `weights`\n",
    "    :return: Tuple of (logits, train_op, cross_entropy_loss)\n",
    "    :raises ValueError: if the number of class weights differs from num_classes\n",
    "    \"\"\"\n",
    "    if class_weights is None:\n",
    "        class_weights = weights\n",
    "    if len(class_weights) != num_classes:\n",
    "        raise ValueError('Expected {} class weights, got {}'.format(num_classes, len(class_weights)))\n",
    "    logits = tf.reshape(nn_last_layer, (-1, num_classes))\n",
    "    labels = tf.reshape(correct_label, (-1, num_classes))\n",
    "    # log_softmax is numerically stable; tf.log(tf.nn.softmax(...)) gives -inf, and hence NaN losses,\n",
    "    # as soon as a class probability underflows to 0.\n",
    "    log_probs = tf.nn.log_softmax(logits)\n",
    "    cross_entropy = -tf.reduce_sum(tf.multiply(labels * log_probs, class_weights), axis=1)\n",
    "    cross_entropy_loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')\n",
    "    # kernel_regularizer only records the L2 terms in the REGULARIZATION_LOSSES collection; they have\n",
    "    # no effect unless added to the minimized objective. (This also picks up any such terms already\n",
    "    # in the graph, e.g. from the loaded VGG model, if it defines any.)\n",
    "    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)\n",
    "    if reg_losses:\n",
    "        total_loss = cross_entropy_loss + tf.add_n(reg_losses)\n",
    "    else:\n",
    "        total_loss = cross_entropy_loss\n",
    "    train_op = tf.train.RMSPropOptimizer(learning_rate).minimize(total_loss)\n",
    "    return logits, train_op, cross_entropy_loss\n",
    "\n",
    "print(\"Optimize Test:\")\n",
    "# NOTE(review): the provided test stays disabled; presumably it uses a num_classes that does not match\n",
    "# the 3 entries in `weights` (confirm in project_tests before re-enabling).\n",
    "#tests.test_optimize(optimize)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_nn(sess, epochs, batch_size, get_batches_fn, train_op, cross_entropy_loss, input_image,\n",
    "             correct_label, keep_prob, learning_rate, saver, data_dir, save_every=10):\n",
    "    \"\"\"\n",
    "    Train neural network and print out the loss during training.\n",
    "    :param sess: TF Session\n",
    "    :param epochs: Number of epochs\n",
    "    :param batch_size: Batch size\n",
    "    :param get_batches_fn: Function to get batches of training data.  Call using get_batches_fn(batch_size)\n",
    "    :param train_op: TF Operation to train the neural network\n",
    "    :param cross_entropy_loss: TF Tensor for the amount of loss\n",
    "    :param input_image: TF Placeholder for input images\n",
    "    :param correct_label: TF Placeholder for label images\n",
    "    :param keep_prob: TF Placeholder for dropout keep probability (fed KEEP_PROB)\n",
    "    :param learning_rate: TF Placeholder for learning rate (fed LEARNING_RATE)\n",
    "    :param saver: tf.train.Saver used to write checkpoints\n",
    "    :param data_dir: Directory the checkpoints are written to; created if missing\n",
    "    :param save_every: Checkpoint every `save_every` epochs (0 disables periodic saves); the last epoch is always saved\n",
    "    :raises ValueError: if get_batches_fn yields no batches in an epoch\n",
    "    \"\"\"\n",
    "    os.makedirs(data_dir, exist_ok=True)\n",
    "    for epoch in range(epochs):\n",
    "        start_time = time.time()\n",
    "        batch_losses = []\n",
    "        for images, targets in get_batches_fn(batch_size):\n",
    "            _, loss = sess.run([train_op, cross_entropy_loss],\n",
    "                               feed_dict={input_image: images, correct_label: targets,\n",
    "                                          keep_prob: KEEP_PROB, learning_rate: LEARNING_RATE})\n",
    "            batch_losses.append(loss)\n",
    "        if not batch_losses:\n",
    "            # Otherwise the epoch summary below would reference an undefined loss.\n",
    "            raise ValueError('get_batches_fn yielded no batches; check the training data path')\n",
    "        # Report the mean over the epoch rather than the noisy loss of its last batch.\n",
    "        mean_loss = sum(batch_losses) / len(batch_losses)\n",
    "        print(\"Epoch: {}\".format(epoch + 1), \"/ {}\".format(epochs), \" Mean loss: {:.3f}\".format(mean_loss), \" Time: \",\n",
    "              str(timedelta(seconds=(time.time() - start_time))))\n",
    "        is_last_epoch = epoch + 1 == epochs\n",
    "        if is_last_epoch or (save_every and (epoch + 1) % save_every == 0):\n",
    "            saver.save(sess, os.path.join(data_dir, 'cont_epoch_' + str(epoch) + '.ckpt'))\n",
    "\n",
    "#Don't use the provided test, as we have a different input to the function\n",
    "#tests.test_train_nn(train_nn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def run():\n",
    "    \"\"\"\n",
    "    Train the FCN on the images under DATA_DIR/leftImg8bit for EPOCHS epochs, checkpointing to MODEL_DIR,\n",
    "    then save inference samples with helper.save_inference_samples (passing RUNS_DIR).\n",
    "    If the pretrained VGG is missing, fetch it first with helper.maybe_download_pretrained_vgg(DATA_DIR).\n",
    "    \"\"\"\n",
    "    print(\"Start training...\")\n",
    "    os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "    config = tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))\n",
    "    # Build in a fresh graph so re-running this cell does not stack a second VGG copy onto the\n",
    "    # default graph (which would also shift the auto-generated variable names saved in checkpoints).\n",
    "    with tf.Graph().as_default(), tf.Session(config=config) as sess:\n",
    "        # The graph-level seed only affects ops created after it is set, so set it before building the model.\n",
    "        tf.set_random_seed(123)\n",
    "        vgg_path = os.path.join(DATA_DIR, 'vgg')\n",
    "        get_batches_fn = helper.gen_batch_function(os.path.join(DATA_DIR, 'leftImg8bit'), IMAGE_SHAPE)\n",
    "        image_input, keep_prob, layer3, layer4, layer7 = load_vgg(sess, vgg_path)\n",
    "        output = layers(layer3, layer4, layer7, NUM_CLASSES)\n",
    "        correct_label = tf.placeholder(dtype=tf.float32, shape=(None, None, None, NUM_CLASSES))\n",
    "        learning_rate = tf.placeholder(dtype=tf.float32)\n",
    "        logits, train_op, cross_entropy_loss = optimize(output, correct_label, learning_rate, NUM_CLASSES)\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        saver = tf.train.Saver()\n",
    "        train_nn(sess, EPOCHS, BATCH_SIZE, get_batches_fn, train_op, cross_entropy_loss, image_input, correct_label,\n",
    "                 keep_prob, learning_rate, saver, MODEL_DIR)\n",
    "        helper.save_inference_samples(RUNS_DIR, DATA_DIR, sess, IMAGE_SHAPE, logits, keep_prob, image_input, NUM_CLASSES)\n",
    "        # OPTIONAL: Apply the trained model to a video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train from scratch for EPOCHS epochs (long-running); see run() for where checkpoints and samples go.\n",
    "run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "def save_samples(checkpoint_dir='./models_3col/'):\n",
    "    \"\"\"\n",
    "    Rebuild the network, restore its weights from the latest checkpoint in `checkpoint_dir`,\n",
    "    and save inference samples with helper.save_inference_samples (passing RUNS_DIR).\n",
    "    :param checkpoint_dir: Directory holding the checkpoint to restore\n",
    "    :raises ValueError: if `checkpoint_dir` contains no checkpoint\n",
    "    \"\"\"\n",
    "    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "    if checkpoint_path is None:\n",
    "        raise ValueError('No checkpoint found in {}'.format(checkpoint_dir))\n",
    "    # Fresh graph, so tf.layers assigns the same variable names as during training.\n",
    "    with tf.Graph().as_default(), tf.Session() as sess:\n",
    "        vgg_path = os.path.join(DATA_DIR, 'vgg')\n",
    "        image_input, keep_prob, layer3, layer4, layer7 = load_vgg(sess, vgg_path)\n",
    "        output = layers(layer3, layer4, layer7, NUM_CLASSES)\n",
    "        correct_label = tf.placeholder(dtype=tf.float32, shape=(None, None, None, NUM_CLASSES))\n",
    "        learning_rate = tf.placeholder(dtype=tf.float32)\n",
    "        # The optimizer is built only so this graph matches the one used for training.\n",
    "        logits, _, _ = optimize(output, correct_label, learning_rate, NUM_CLASSES)\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        # Restore into the variables `logits` actually uses. tf.train.import_meta_graph would add a\n",
    "        # second copy of the network and restore into that copy instead.\n",
    "        tf.train.Saver().restore(sess, checkpoint_path)\n",
    "        helper.save_inference_samples(RUNS_DIR, DATA_DIR, sess, IMAGE_SHAPE, logits, keep_prob, image_input, NUM_CLASSES)\n",
    "\n",
    "save_samples()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cont(epochs=10, checkpoint_dir='./models_3col/', train_subdir='data_road/training'):\n",
    "    \"\"\"\n",
    "    Resume training from the latest checkpoint in `checkpoint_dir`, writing new checkpoints to MODEL_DIR,\n",
    "    then save inference samples with helper.save_inference_samples (passing RUNS_DIR).\n",
    "    :param epochs: Number of additional epochs to train\n",
    "    :param checkpoint_dir: Directory holding the checkpoint to resume from\n",
    "    :param train_subdir: Training-data folder under DATA_DIR passed to helper.gen_batch_function\n",
    "    :raises ValueError: if `checkpoint_dir` contains no checkpoint\n",
    "    \"\"\"\n",
    "    # NOTE(review): the default train_subdir is the KITTI road folder, whereas run() trains on\n",
    "    # DATA_DIR/leftImg8bit; confirm which dataset the checkpoint was trained on.\n",
    "    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "    if checkpoint_path is None:\n",
    "        raise ValueError('No checkpoint found in {}'.format(checkpoint_dir))\n",
    "    # Fresh graph, so tf.layers assigns the same variable names as during training.\n",
    "    with tf.Graph().as_default(), tf.Session() as sess:\n",
    "        vgg_path = os.path.join(DATA_DIR, 'vgg')\n",
    "        get_batches_fn = helper.gen_batch_function(os.path.join(DATA_DIR, train_subdir), IMAGE_SHAPE)\n",
    "        image_input, keep_prob, layer3, layer4, layer7 = load_vgg(sess, vgg_path)\n",
    "        output = layers(layer3, layer4, layer7, NUM_CLASSES)\n",
    "        correct_label = tf.placeholder(dtype=tf.float32, shape=(None, None, None, NUM_CLASSES))\n",
    "        learning_rate = tf.placeholder(dtype=tf.float32)\n",
    "        logits, train_op, cross_entropy_loss = optimize(output, correct_label, learning_rate, NUM_CLASSES)\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        # One Saver over this graph both restores the checkpoint and writes the new ones.\n",
    "        # tf.train.import_meta_graph would instead add a second copy of the network and restore into it,\n",
    "        # so the copy being trained here would never receive the restored weights.\n",
    "        saver = tf.train.Saver()\n",
    "        saver.restore(sess, checkpoint_path)\n",
    "        train_nn(sess, epochs, BATCH_SIZE, get_batches_fn, train_op, cross_entropy_loss, image_input, correct_label,\n",
    "                 keep_prob, learning_rate, saver, MODEL_DIR)\n",
    "        helper.save_inference_samples(RUNS_DIR, DATA_DIR, sess, IMAGE_SHAPE, logits, keep_prob, image_input, NUM_CLASSES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Continue training from the latest checkpoint in ./models_3col/ (see cont()).\n",
    "cont()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
